import numpy as np
import torch
import os
from VDN import VDN


class Agents:
    """Wrapper around a ``VDN`` policy for batched acting and learning.

    All sizes are read from ``args``: ``vec_env``, ``n_agents``, ``rnn_layer``,
    ``hidden_size`` and the boolean ``cuda`` flag that selects the device.
    """

    def __init__(self, args):
        """Store ``args`` and build the ``VDN`` policy from it."""
        self.args = args
        self.policy = VDN(args)

    @property
    def _device(self):
        """Return the torch device selected by ``args.cuda``."""
        return torch.device("cuda" if self.args.cuda else "cpu")

    def select_actions(self, s, available_u, h, epsilon):
        """Choose one action per agent in every vectorized environment.

        Args:
            s: NumPy array that is flattened to
                ``(vec_env * n_agents, -1)`` before being fed to the network.
            available_u: NumPy array flattened the same way as ``s``
                (presumably an action-availability mask -- confirm against
                ``VDN.critic_network``).
            h: Recurrent state; reshaped to
                ``(rnn_layer, vec_env * n_agents, -1)``. It is not moved
                between devices here, so it must already live on the device
                the network uses (``init_hidden`` produces such a tensor).
            epsilon: Forwarded to the network as ``eps``.

        Returns:
            A tuple ``(us, hs)`` where ``us`` is a NumPy array of shape
            ``(vec_env, n_agents)`` and ``hs`` is the next recurrent state as
            a tensor of shape ``(rnn_layer, vec_env, n_agents, -1)``.
        """
        n_rows = self.args.vec_env * self.args.n_agents
        device = self._device

        o = torch.from_numpy(s).float().reshape(n_rows, -1).to(device)
        a_u = torch.from_numpy(available_u).float().reshape(n_rows, -1).to(device)
        h = h.reshape(self.args.rnn_layer, n_rows, -1)

        # Acting never needs gradients. Without no_grad the returned hidden
        # state keeps its autograd graph, and because it is fed back in on the
        # next step the graphs chain together and memory grows every step.
        with torch.no_grad():
            action, _, hidden_next = self.policy.critic_network(o, a_u, h, eps=epsilon)

        us = action.detach().cpu().reshape(self.args.vec_env, self.args.n_agents).numpy()
        hs = hidden_next.reshape(self.args.rnn_layer, self.args.vec_env, self.args.n_agents, -1)

        return us, hs

    def learn(self, transitions):
        """Run one training step on a batch of transitions.

        Args:
            transitions: Mapping from key to array-like data. Every value is
                converted to a ``float32`` tensor on the configured device.
                The mapping itself is left untouched.

        Returns:
            The loss returned by ``self.policy.loss`` for this batch.
        """
        device = self._device
        # Build a new dict instead of overwriting the caller's values: the
        # caller may reuse the batch, and an already-converted CUDA tensor
        # cannot be passed through np.array again.
        torch_transitions = {
            key: torch.tensor(np.array(value), dtype=torch.float32, device=device)
            for key, value in transitions.items()
        }

        critic_loss = self.policy.loss(torch_transitions)

        self.policy.train(critic_loss)
        return critic_loss

    def init_hidden(self):
        """Return a zero recurrent state.

        Returns:
            A ``float32`` tensor of zeros with shape
            ``(rnn_layer, vec_env, n_agents, hidden_size)`` on the configured
            device.
        """
        return torch.zeros(
            self.args.rnn_layer,
            self.args.vec_env,
            self.args.n_agents,
            self.args.hidden_size,
            dtype=torch.float32,
            device=self._device,
        )
